/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "c/ddk/graph/tensor.h"
#include "resource_manager.h"
#include "framework/infra/log/log.h"
#include "graph/tensor.h"

/**
 * Translation table from the C-API element type (HIAI_DataType) to the
 * graph-engine element type (ge::DataType).
 *
 * Only the types listed here are accepted when creating a tensor through the
 * C API; any HIAI_DataType that is missing from this table is rejected.
 */
static const std::map<HIAI_DataType, ge::DataType> DATA_TYPE_MAP {
    // Floating-point types.
    { HIAI_DATATYPE_FLOAT16, ge::DataType::DT_FLOAT16 },
    { HIAI_DATATYPE_FLOAT32, ge::DataType::DT_FLOAT },
    { HIAI_DATATYPE_DOUBLE,  ge::DataType::DT_DOUBLE },
    // Signed integer types.
    { HIAI_DATATYPE_INT8,    ge::DataType::DT_INT8 },
    { HIAI_DATATYPE_INT16,   ge::DataType::DT_INT16 },
    { HIAI_DATATYPE_INT32,   ge::DataType::DT_INT32 },
    { HIAI_DATATYPE_INT64,   ge::DataType::DT_INT64 },
    // Unsigned integer types.
    { HIAI_DATATYPE_UINT8,   ge::DataType::DT_UINT8 },
    { HIAI_DATATYPE_UINT32,  ge::DataType::DT_UINT32 },
    // Boolean.
    { HIAI_DATATYPE_BOOL,    ge::DataType::DT_BOOL },
};

/**
 * @brief Creates a ge::Tensor from caller-supplied data and registers it with a resource manager.
 *
 * @param resMgr    Resource manager handle; receives shared ownership of the created tensor.
 *                  Must not be null.
 * @param data      Pointer to the tensor's raw bytes; must not be null. Whether ge::Tensor::SetData
 *                  copies or merely references this buffer is decided by ge::Tensor (TODO confirm).
 * @param dataLen   Number of bytes at @p data, forwarded unchanged to ge::Tensor::SetData. This
 *                  function does not check it against the shape or element size.
 * @param shape     Array of @p shapeSize dimension sizes; must not be null.
 * @param shapeSize Number of entries in @p shape.
 * @param typeC     Element type; must be a key of DATA_TYPE_MAP, otherwise nullptr is returned.
 * @param formatC   Tensor layout, converted to ge::Format with a plain static_cast (not validated here).
 * @return Non-owning pointer to the new tensor, or nullptr on invalid input. The tensor's lifetime
 *         is governed by the shared_ptr stored in @p resMgr; presumably it lives until the
 *         resource manager releases it (confirm against ResourceManager).
 */
TensorHandle HIAI_IR_CreateTensor(ResMgrHandle resMgr, void* data, size_t dataLen,
    const int64_t shape[], size_t shapeSize, HIAI_DataType typeC, HIAI_Format formatC)
{
    if (resMgr == nullptr || data == nullptr || shape == nullptr) {
        FMK_LOGE("Input Handle [resMgr] or [data] or [shape] is nullptr.");
        return nullptr;
    }

    // Validate the element type before allocating anything, so a rejected request does no work.
    const auto iter = DATA_TYPE_MAP.find(typeC);
    if (iter == DATA_TYPE_MAP.end()) {
        FMK_LOGE("typeC %d is not support now", static_cast<int>(typeC));
        return nullptr;
    }

    std::vector<int64_t> shapeVec(shape, shape + shapeSize);
    ge::TensorDesc desc(ge::Shape(shapeVec), static_cast<ge::Format>(formatC), iter->second);

    // Create the concrete tensor type directly. Going through BasePtr and back with
    // dynamic_pointer_cast only added an RTTI lookup and an unchecked possibly-null result.
    std::shared_ptr<ge::Tensor> tensorPtr = std::make_shared<ge::Tensor>();
    tensorPtr->SetTensorDesc(desc);
    tensorPtr->SetData(static_cast<uint8_t*>(data), dataLen);

    // Store the tensor in the ResourceManager. The manager holds the owning shared_ptr, which keeps
    // the raw handle returned below valid.
    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    resMgrPtr->StoreSrcPtr(tensorPtr);

    // Return a raw, non-owning pointer to the C caller.
    return tensorPtr.get();
}